package com.proxy.server.frontend.handler;

import com.proxy.common.constant.CapabilitiesType;
import com.proxy.common.constant.ErrorCodeType;
import com.proxy.common.constant.VersionsType;
import com.proxy.common.model.Account;
import com.proxy.common.packet.*;
import com.proxy.common.utils.CharsetUtil;
import com.proxy.common.utils.RandomUtil;
import com.proxy.common.utils.SecurityUtil;
import com.proxy.server.context.AppContext;
import com.proxy.server.frontend.FrontendConnection;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;

/**
 * User authentication for frontend (client-facing) connections.
 *
 * <p>Sends the initial handshake packet and validates the client's auth
 * response (user name, scrambled password, optional database) against the
 * accounts and schema configured on the frontend proxy.
 *
 * Created by liufish on 16/7/15.
 */
public class AuthHandler {


    /**
     * Sends the server handshake packet to a newly connected client.
     *
     * <p>A fresh 20-byte scramble (8 + 12 random bytes) is generated. It is
     * stored on the connection and sent to the client. The connection is
     * marked as not yet authenticated.
     *
     * @param connection the frontend connection that just connected
     */
    public static void handshake(final FrontendConnection connection) {

        byte[] seedPart1 = RandomUtil.randomBytes(8);
        byte[] seedPart2 = RandomUtil.randomBytes(12);

        // The full scramble is what the client's password response is checked against.
        byte[] seed = new byte[seedPart1.length + seedPart2.length];
        System.arraycopy(seedPart1, 0, seed, 0, seedPart1.length);
        System.arraycopy(seedPart2, 0, seed, seedPart1.length, seedPart2.length);

        // Record the auth state before the handshake goes out. The seed is then
        // guaranteed to be present whenever the client's response is processed.
        // The previous "set in a write listener" approach relied on isDone(),
        // which is always true inside operationComplete.
        connection.setAuth(false);
        connection.setSeed(seed);

        HandshakePacket handshakePacket = new HandshakePacket();
        handshakePacket.packetId = 0;
        handshakePacket.protocolVersion = VersionsType.PROTOCOL_VERSION;
        handshakePacket.serverVersion = VersionsType.SERVER_VERSION;

        // Thread id: the channel's hashCode is used as a stand-in.
        handshakePacket.threadId = connection.getChannel().hashCode();
        handshakePacket.seed = seedPart1;
        handshakePacket.serverCapabilities = CapabilitiesType.getServerCapabilities();
        // Charset index from the frontend proxy config (original note: default 45 — confirm in config).
        handshakePacket.serverCharsetIndex = AppContext.getInstance().getFrontendProxy().getCharsetIndex();
        // 2 = SERVER_STATUS_AUTOCOMMIT in the MySQL protocol.
        handshakePacket.serverStatus = 2;
        handshakePacket.restOfScrambleBuff = seedPart2;

        connection.getChannel().writeAndFlush(handshakePacket);
    }


    /**
     * Handles the client's authentication response.
     *
     * <p>Does nothing if the connection is already authenticated. Otherwise
     * the following are checked in order:
     * <ol>
     *   <li>the user exists;</li>
     *   <li>the password matches the seed stored on the connection;</li>
     *   <li>the requested database, if any, equals the configured schema.</li>
     * </ol>
     * On failure an error packet is written. On success an OK packet is sent.
     * The connection is marked authenticated once that OK write succeeds.
     *
     * @param bin        raw packet received from the client
     * @param connection the frontend connection being authenticated
     * @throws Exception if the packet cannot be parsed or the error message cannot be encoded
     */
    public static void auth(BinaryPacket bin, final FrontendConnection connection) throws Exception {

        if (connection.isAuth()) {
            return;
        }

        final AuthPacket authPacket = new AuthPacket();
        authPacket.read(bin);

        // Unknown user and wrong password get the same message, so a client
        // cannot probe which user names exist.
        if (authPacket.user == null
                || !_checkUser(authPacket, connection)
                || !_checkPassword(authPacket, connection.getSeed())) {
            _writeError(connection, authPacket, ErrorCodeType.ER_ACCESS_DENIED_ERROR,
                    String.format("Access denied for user '%s'", authPacket.user));
            return;
        }

        // An empty or absent database is allowed; otherwise it must be the configured schema.
        String database = authPacket.database;
        if (database != null && database.length() > 0
                && !database.equalsIgnoreCase(AppContext.getInstance().getFrontendProxy().getSchema())) {
            _writeError(connection, authPacket, ErrorCodeType.ER_DBACCESS_DENIED_ERROR,
                    String.format("Access denied for user '%s' to database '%s'", authPacket.user, database));
            return;
        }

        connection.getChannel().writeAndFlush(OkPacket.getOk(++bin.packetId))
                .addListener(new GenericFutureListener<Future<Void>>() {
                    @Override
                    public void operationComplete(Future<Void> future) throws Exception {
                        // Only treat the session as authenticated if the OK packet was actually written.
                        if (future.isSuccess()) {
                            connection.setAuth(true);
                            connection.setSchema(authPacket.database);
                            connection.setUser(authPacket.user);
                        }
                    }
                });
    }


    /**
     * Writes an error packet in reply to the given auth packet.
     *
     * @param connection destination connection
     * @param authPacket the packet being answered (its id + 1 becomes the error packet id)
     * @param errNo      error code to report
     * @param message    human-readable message, encoded with the proxy's configured charset
     * @throws Exception if the message cannot be encoded
     */
    private static void _writeError(FrontendConnection connection, AuthPacket authPacket,
                                    int errNo, String message) throws Exception {
        ErrorPacket error = new ErrorPacket();
        error.packetId = authPacket.packetId + 1;
        error.errNo = errNo;
        error.message = message.getBytes(
                CharsetUtil.getCharset(AppContext.getInstance().getFrontendProxy().getCharsetIndex()));
        connection.write(error);
    }


    /**
     * Looks up a configured account by user name (case-insensitive).
     *
     * @param user user name from the client; may be {@code null}
     * @return the first matching account, or {@code null} if none matches or {@code user} is null
     */
    private static Account _findAccount(String user) {
        if (user == null) {
            return null;
        }
        for (Account account : AppContext.getInstance().getFrontendProxy().getUsers()) {
            if (user.equalsIgnoreCase(account.getUser())) {
                return account;
            }
        }
        return null;
    }


    /**
     * Checks that the user in the auth packet is configured.
     * If so, attaches the account to the connection.
     *
     * @param authPacket the client's auth packet
     * @param connection the connection to attach the matched account to
     * @return {@code true} if a matching account exists
     */
    private static boolean _checkUser(AuthPacket authPacket, FrontendConnection connection) {
        Account account = _findAccount(authPacket.user);
        if (account == null) {
            return false;
        }
        connection.setAccount(account);
        return true;
    }


    /**
     * Verifies the client's scrambled password against the configured one.
     *
     * <p>If the configured password is empty or absent, the client must also
     * send an empty password. Otherwise the configured password is scrambled
     * with {@code seed} using {@link SecurityUtil#scramble411}. The result is
     * compared in constant time with the client's response.
     *
     * @param authPacket the client's auth packet
     * @param seed       the scramble sent in the handshake; may be {@code null} if none was recorded
     * @return {@code true} if the password is accepted
     */
    private static boolean _checkPassword(AuthPacket authPacket, byte[] seed) {

        Account account = _findAccount(authPacket.user);
        String expected = (account == null) ? null : account.getPassword();
        byte[] response = authPacket.password;
        boolean responseEmpty = response == null || response.length == 0;

        // Empty / no password configured: accept only an empty response.
        if (expected == null || expected.length() == 0) {
            return responseEmpty;
        }
        // A password is required, and without a seed the response cannot be verified.
        if (responseEmpty || seed == null) {
            return false;
        }

        byte[] scrambled;
        try {
            // Explicit charset: String.getBytes() would depend on the JVM default charset.
            scrambled = SecurityUtil.scramble411(expected.getBytes(StandardCharsets.UTF_8), seed);
        } catch (NoSuchAlgorithmException e) {
            // Without the digest algorithm the password cannot be verified; reject.
            return false;
        }
        // Constant-time comparison avoids leaking how many leading bytes matched.
        return scrambled != null && MessageDigest.isEqual(scrambled, response);
    }

}
